{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "797a44a6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:28:32.762244Z",
     "start_time": "2023-11-04T11:28:32.756743Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d7bf1d78",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:28:33.692354Z",
     "start_time": "2023-11-04T11:28:33.685691Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Point the `http_proxy` / `https_proxy` environment variables at a proxy on localhost port 7890.\n",
    "# NOTE(review): this address is machine-specific; presumably it is needed for the model download in the next cell.\n",
    "# Adjust or drop these two assignments on machines that do not run such a proxy.\n",
    "os.environ[\"http_proxy\"] = \"http://127.0.0.1:7890\"\n",
    "os.environ[\"https_proxy\"] = \"http://127.0.0.1:7890\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d20935e7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:28:41.433170Z",
     "start_time": "2023-11-04T11:28:34.782166Z"
    }
   },
   "outputs": [],
   "source": [
    "gpt2 = AutoModelForCausalLM.from_pretrained('gpt2')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5252299",
   "metadata": {},
   "source": [
    "```\n",
    "# modeling_utils._load_state_dict_into_model\n",
    "def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):\n",
    "    ...\n",
    "```\n",
    "\n",
    "- lm_head 结构是在 `GPT2LMHeadModel` 内部定义和创建的，\n",
    "    - 但在 from_pretrained 加载预训练参数时，其参数是从 gpt2model（一个Transformer架构）的 wte 里来的；"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a0e2506",
   "metadata": {},
   "source": [
    "## model arch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b78acff7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T12:13:46.758547Z",
     "start_time": "2023-11-04T12:13:46.745824Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPT2LMHeadModel(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-11): 12 x GPT2Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPT2Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): GPT2MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (act): NewGELUActivation()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpt2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "795fd467",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:28:44.481171Z",
     "start_time": "2023-11-04T11:28:44.465463Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(gpt2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "769becca",
   "metadata": {},
   "source": [
    "### output_embeddings()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "136461b7",
   "metadata": {},
   "source": [
    "\n",
    "- input_embeddings: wte\n",
    "    - Word Token Embeddings (cf. `wpe`: Word Position Embeddings)\n",
    "- output_embeddings: lm_head\n",
    "\n",
    "```\n",
    "class GPT2LMHeadModel(GPT2PreTrainedModel):\n",
    "    _tied_weights_keys = [\"lm_head.weight\"]\n",
    "    \n",
    "    ...\n",
    "    \n",
    "    def get_input_embeddings(self):\n",
    "        return self.transformer.wte\n",
    "        \n",
    "    def get_output_embeddings(self):\n",
    "        return self.lm_head\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea0846ad",
   "metadata": {},
   "source": [
    "## word embedding & lm_head"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb572b87",
   "metadata": {},
   "source": [
    "- The GPT2 Model transformer with a **language modeling head on top** (`lm_head`) (linear layer with weights **tied** to the input embeddings).\n",
    "    - https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9d077c55",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:28:50.382174Z",
     "start_time": "2023-11-04T11:28:50.355988Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453],\n",
       "        [ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],\n",
       "        [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],\n",
       "        ...,\n",
       "        [-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],\n",
       "        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],\n",
       "        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]],\n",
       "       requires_grad=True)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpt2.transformer.wte.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d6218c6f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:28:52.225372Z",
     "start_time": "2023-11-04T11:28:52.212121Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453],\n",
       "        [ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],\n",
       "        [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],\n",
       "        ...,\n",
       "        [-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],\n",
       "        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],\n",
       "        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]],\n",
       "       requires_grad=True)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpt2.lm_head.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4e7e7c9a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:28:54.026387Z",
     "start_time": "2023-11-04T11:28:54.016216Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.nn.parameter.Parameter"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(gpt2.lm_head.weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "46a80a10",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:28:55.652053Z",
     "start_time": "2023-11-04T11:28:55.642734Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpt2.lm_head.weight is gpt2.transformer.wte.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0f7ae7b2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:29:21.134345Z",
     "start_time": "2023-11-04T11:29:21.125352Z"
    }
   },
   "outputs": [],
   "source": [
    "gpt2.lm_head.weight.data_ptr??"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "51280b6f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:29:31.028997Z",
     "start_time": "2023-11-04T11:29:31.019353Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Returns the address of the first element of :attr:`self` tensor.\n",
    "gpt2.lm_head.weight.data_ptr() == gpt2.transformer.wte.weight.data_ptr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ff0f5f62",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T12:15:49.608058Z",
     "start_time": "2023-11-04T12:15:49.594716Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([50257, 768])\n",
      "torch.Size([50257, 768])\n"
     ]
    }
   ],
   "source": [
    "print(gpt2.state_dict()['transformer.wte.weight'].shape)\n",
    "print(gpt2.state_dict()['lm_head.weight'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ef8f3907",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-04T11:30:10.072683Z",
     "start_time": "2023-11-04T11:30:10.052655Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453],\n",
      "        [ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],\n",
      "        [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],\n",
      "        ...,\n",
      "        [-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],\n",
      "        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],\n",
      "        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]])\n",
      "tensor([[-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453],\n",
      "        [ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],\n",
      "        [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],\n",
      "        ...,\n",
      "        [-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],\n",
      "        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],\n",
      "        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]])\n"
     ]
    }
   ],
   "source": [
    "# Both state-dict entries share one underlying buffer, so the weights occupy memory only once\n",
    "print(gpt2.state_dict()['transformer.wte.weight'])\n",
    "print(gpt2.state_dict()['lm_head.weight'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b2f3462",
   "metadata": {},
   "source": [
    "### how tied"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23c0ba13",
   "metadata": {},
   "source": [
    "```\n",
    "\n",
    "class PreTrainedModel(nn.Module, ...):\n",
    "    ...\n",
    "    def tie_weights(self):\n",
    "\n",
    "        if getattr(self.config, \"tie_word_embeddings\", True):\n",
    "            output_embeddings = self.get_output_embeddings()\n",
    "            if output_embeddings is not None:\n",
    "                self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())\n",
    "\n",
    "    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):\n",
    "        \"\"\"Tie or clone module weights depending of whether we are using TorchScript or not\"\"\"\n",
    "        if self.config.torchscript:\n",
    "            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())\n",
    "        else:\n",
    "            output_embeddings.weight = input_embeddings.weight\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c8a0e4a",
   "metadata": {},
   "source": [
    "### tied or shared tensors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76b3650f",
   "metadata": {},
   "source": [
    "- https://huggingface.co/docs/safetensors/torch_shared_tensors\n",
    "\n",
    "    - Pytorch uses shared tensors for some computation. This is extremely interesting to reduce memory usage in general.\n",
    "\n",
    "    - One very classic use case is in transformers the embeddings are shared with lm_head. By using the same matrix, the model uses less parameters, and gradients flow much better to the embeddings (which is the start of the model, so they don’t flow easily there, whereas lm_head is at the tail of the model, so gradients are extremely good over there, since they are the same tensors, they both benefit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "9b9752d8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-02T16:03:19.585029Z",
     "start_time": "2023-11-02T16:03:19.569927Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from torch import nn\n",
    "\n",
    "class Model(nn.Module):\n",
    "    \"\"\"Two stacked 100 -> 100 linear layers, each created by its own `nn.Linear` call.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.a = nn.Linear(100, 100)\n",
    "        self.b = nn.Linear(100, 100)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return `b(a(x))`.\"\"\"\n",
    "        return self.b(self.a(x))\n",
    "\n",
    "\n",
    "model = Model()\n",
    "# Show only the parameter names; printing the whole state dict would dump every tensor.\n",
    "print(model.state_dict().keys())\n",
    "# odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])\n",
    "torch.save(model.state_dict(), \"model.bin\")\n",
    "# Baseline without sharing: `a` and `b` are independent layers, so both sets of parameters are written and the file is ~80k."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "65bb6b61",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-02T16:02:49.658855Z",
     "start_time": "2023-11-02T16:02:49.635945Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('a.weight', tensor([[ 0.0777,  0.0683,  0.0710,  ...,  0.0692,  0.0377,  0.0951],\n",
      "        [ 0.0986,  0.0207, -0.0691,  ...,  0.0168,  0.0718, -0.0220],\n",
      "        [ 0.0806, -0.0145, -0.0851,  ...,  0.0653,  0.0729, -0.0443],\n",
      "        ...,\n",
      "        [ 0.0818, -0.0725,  0.0595,  ..., -0.0729,  0.0758,  0.0752],\n",
      "        [ 0.0983, -0.0282, -0.0066,  ..., -0.0467, -0.0237, -0.0505],\n",
      "        [-0.0291, -0.0101, -0.0842,  ..., -0.0088, -0.0748,  0.0641]])), ('a.bias', tensor([ 0.0377,  0.0652,  0.0465, -0.0092,  0.0909,  0.0700, -0.0753, -0.0464,\n",
      "        -0.0905, -0.0142,  0.0044,  0.0673, -0.0510,  0.0401,  0.0207, -0.0703,\n",
      "         0.0661, -0.0329, -0.0917,  0.0600,  0.0594,  0.0968,  0.0822, -0.0912,\n",
      "         0.0221,  0.0809, -0.0047, -0.0823,  0.0861,  0.0808, -0.0131, -0.0903,\n",
      "        -0.0515,  0.0507, -0.0054,  0.0317, -0.0846, -0.0964,  0.0124, -0.0123,\n",
      "         0.0576,  0.0543, -0.0357,  0.0272,  0.0058, -0.0178,  0.0899,  0.0117,\n",
      "         0.0805, -0.0146, -0.0219,  0.0898,  0.0014,  0.0348,  0.0245, -0.0595,\n",
      "         0.0766,  0.0728,  0.0643, -0.0987, -0.0003, -0.0414,  0.0767, -0.0187,\n",
      "        -0.0224,  0.0797,  0.0614, -0.0766,  0.0446, -0.0091,  0.0286,  0.0228,\n",
      "         0.0843,  0.0702, -0.0594,  0.0831, -0.0560,  0.0565, -0.0780,  0.0018,\n",
      "         0.0831,  0.0598, -0.0688, -0.0731, -0.0150,  0.0103, -0.0475, -0.0635,\n",
      "        -0.0943,  0.0312, -0.0791,  0.0889,  0.0783,  0.0273,  0.0075,  0.0648,\n",
      "        -0.0479, -0.0622,  0.0339, -0.0003])), ('b.weight', tensor([[ 0.0777,  0.0683,  0.0710,  ...,  0.0692,  0.0377,  0.0951],\n",
      "        [ 0.0986,  0.0207, -0.0691,  ...,  0.0168,  0.0718, -0.0220],\n",
      "        [ 0.0806, -0.0145, -0.0851,  ...,  0.0653,  0.0729, -0.0443],\n",
      "        ...,\n",
      "        [ 0.0818, -0.0725,  0.0595,  ..., -0.0729,  0.0758,  0.0752],\n",
      "        [ 0.0983, -0.0282, -0.0066,  ..., -0.0467, -0.0237, -0.0505],\n",
      "        [-0.0291, -0.0101, -0.0842,  ..., -0.0088, -0.0748,  0.0641]])), ('b.bias', tensor([ 0.0377,  0.0652,  0.0465, -0.0092,  0.0909,  0.0700, -0.0753, -0.0464,\n",
      "        -0.0905, -0.0142,  0.0044,  0.0673, -0.0510,  0.0401,  0.0207, -0.0703,\n",
      "         0.0661, -0.0329, -0.0917,  0.0600,  0.0594,  0.0968,  0.0822, -0.0912,\n",
      "         0.0221,  0.0809, -0.0047, -0.0823,  0.0861,  0.0808, -0.0131, -0.0903,\n",
      "        -0.0515,  0.0507, -0.0054,  0.0317, -0.0846, -0.0964,  0.0124, -0.0123,\n",
      "         0.0576,  0.0543, -0.0357,  0.0272,  0.0058, -0.0178,  0.0899,  0.0117,\n",
      "         0.0805, -0.0146, -0.0219,  0.0898,  0.0014,  0.0348,  0.0245, -0.0595,\n",
      "         0.0766,  0.0728,  0.0643, -0.0987, -0.0003, -0.0414,  0.0767, -0.0187,\n",
      "        -0.0224,  0.0797,  0.0614, -0.0766,  0.0446, -0.0091,  0.0286,  0.0228,\n",
      "         0.0843,  0.0702, -0.0594,  0.0831, -0.0560,  0.0565, -0.0780,  0.0018,\n",
      "         0.0831,  0.0598, -0.0688, -0.0731, -0.0150,  0.0103, -0.0475, -0.0635,\n",
      "        -0.0943,  0.0312, -0.0791,  0.0889,  0.0783,  0.0273,  0.0075,  0.0648,\n",
      "        -0.0479, -0.0622,  0.0339, -0.0003]))])\n"
     ]
    }
   ],
   "source": [
    "class Model(nn.Module):\n",
    "    \"\"\"Two-step model in which `b` is bound to the same `nn.Linear` instance as `a` (tied parameters).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.a = nn.Linear(100, 100)\n",
    "        # Tie the layers: `b` is the very same module object as `a`, so both names share one weight and one bias\n",
    "        self.b = self.a\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return `b(a(x))`, i.e. the single shared layer applied twice.\"\"\"\n",
    "        return self.b(self.a(x))\n",
    "\n",
    "\n",
    "model = Model()\n",
    "print(model.state_dict())\n",
    "# The state dict still lists four keys ('a.weight', 'a.bias', 'b.weight', 'b.bias'), but the `b.*` entries hold the same tensors as `a.*`\n",
    "torch.save(model.state_dict(), \"model2.bin\")\n",
    "# This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
